-
Notifications
You must be signed in to change notification settings - Fork 14.8k
[JTS] Propagate profile info #153305
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[JTS] Propagate profile info #153305
Conversation
4233ab6
to
5b08108
Compare
6f7ecae
to
4be99fc
Compare
5b08108
to
6a998c8
Compare
6a998c8
to
94c57f3
Compare
@llvm/pr-subscribers-pgo @llvm/pr-subscribers-llvm-transforms Author: Mircea Trofin (mtrofin) ChangesIf the indirect call target being recognized as a jump table has profile info, we can accurately synthesize the branch weights of the switch that replaces the indirect call. Otherwise we insert the "unknown" Part of Issue #147390 Full diff: https://github.com/llvm/llvm-project/pull/153305.diff 3 Files Affected:
diff --git a/llvm/include/llvm/ProfileData/InstrProf.h b/llvm/include/llvm/ProfileData/InstrProf.h
index bab1963dba22e..85a9efe73855b 100644
--- a/llvm/include/llvm/ProfileData/InstrProf.h
+++ b/llvm/include/llvm/ProfileData/InstrProf.h
@@ -665,6 +665,10 @@ class InstrProfSymtab {
return Error::success();
}
+ const std::vector<std::pair<uint64_t, Function *>> &getIDToNameMap() const {
+ return MD5FuncMap;
+ }
+
const StringSet<> &getVTableNames() const { return VTableNames; }
/// Map a function address to its name's MD5 hash. This interface
diff --git a/llvm/lib/Transforms/Scalar/JumpTableToSwitch.cpp b/llvm/lib/Transforms/Scalar/JumpTableToSwitch.cpp
index 7f99cd2060a9d..6719ce64b96b6 100644
--- a/llvm/lib/Transforms/Scalar/JumpTableToSwitch.cpp
+++ b/llvm/lib/Transforms/Scalar/JumpTableToSwitch.cpp
@@ -7,14 +7,24 @@
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Scalar/JumpTableToSwitch.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/ConstantFolding.h"
+#include "llvm/Analysis/CtxProfAnalysis.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/ProfDataUtils.h"
+#include "llvm/ProfileData/InstrProf.h"
#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include <limits>
using namespace llvm;
@@ -33,6 +43,8 @@ static cl::opt<unsigned> FunctionSizeThreshold(
"or equal than this threshold."),
cl::init(50));
+extern cl::opt<bool> ProfcheckDisableMetadataFixes;
+
#define DEBUG_TYPE "jump-table-to-switch"
namespace {
@@ -90,9 +102,11 @@ static std::optional<JumpTableTy> parseJumpTable(GetElementPtrInst *GEP,
return JumpTable;
}
-static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
- DomTreeUpdater &DTU,
- OptimizationRemarkEmitter &ORE) {
+static BasicBlock *
+expandToSwitch(CallBase *CB, const JumpTableTy &JT, DomTreeUpdater &DTU,
+ OptimizationRemarkEmitter &ORE,
+ llvm::function_ref<GlobalValue::GUID(const Function &)>
+ GetGuidForFunction) {
const bool IsVoid = CB->getType() == Type::getVoidTy(CB->getContext());
SmallVector<DominatorTree::UpdateType, 8> DTUpdates;
@@ -115,7 +129,31 @@ static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
IRBuilder<> BuilderTail(CB);
PHINode *PHI =
IsVoid ? nullptr : BuilderTail.CreatePHI(CB->getType(), JT.Funcs.size());
-
+ const auto *ProfMD = CB->getMetadata(LLVMContext::MD_prof);
+
+ SmallVector<uint64_t> BranchWeights;
+ DenseMap<GlobalValue::GUID, uint64_t> GuidToCounter;
+ const bool HadProfile = isValueProfileMD(ProfMD);
+ if (HadProfile) {
+ // The assumptions, coming in, are that the functions in JT.Funcs are
+ // defined in this module (from parseJumpTable).
+ assert(llvm::all_of(
+ JT.Funcs, [](const Function *F) { return F && !F->isDeclaration(); }));
+ BranchWeights.reserve(JT.Funcs.size() + 1);
+ // The first is the default target, which is the unreachable block created
+ // above.
+ BranchWeights.push_back(0U);
+ uint64_t TotalCount = 0;
+ auto Targets = getValueProfDataFromInst(
+ *CB, InstrProfValueKind::IPVK_IndirectCallTarget,
+ std::numeric_limits<uint32_t>::max(), TotalCount);
+
+ for (const auto &[G, C] : Targets) {
+ auto It = GuidToCounter.insert({G, C});
+ assert(It.second);
+ (void)It;
+ }
+ }
for (auto [Index, Func] : llvm::enumerate(JT.Funcs)) {
BasicBlock *B = BasicBlock::Create(Func->getContext(),
"call." + Twine(Index), &F, Tail);
@@ -127,6 +165,11 @@ static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
Call->insertInto(B, B->end());
Switch->addCase(
cast<ConstantInt>(ConstantInt::get(JT.Index->getType(), Index)), B);
+ GlobalValue::GUID FctID = GetGuidForFunction(*Func);
+ // It'd be OK to _not_ find target functions in GuidToCounter, e.g. suppose
+ // just some of the jump targets are taken (for the given profile).
+ BranchWeights.push_back(FctID == 0U ? 0U
+ : GuidToCounter.lookup_or(FctID, 0U));
BranchInst::Create(Tail, B);
if (PHI)
PHI->addIncoming(Call, B);
@@ -136,6 +179,13 @@ static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
return OptimizationRemark(DEBUG_TYPE, "ReplacedJumpTableWithSwitch", CB)
<< "expanded indirect call into switch";
});
+ if (HadProfile && !ProfcheckDisableMetadataFixes) {
+ // At least one of the targets must've been taken.
+ assert(llvm::any_of(BranchWeights, [](uint64_t V) { return V != 0; }));
+ setProfMetadata(F.getParent(), Switch, BranchWeights,
+ *llvm::max_element(BranchWeights));
+ } else
+ setExplicitlyUnknownBranchWeights(*Switch);
if (PHI)
CB->replaceAllUsesWith(PHI);
CB->eraseFromParent();
@@ -150,6 +200,15 @@ PreservedAnalyses JumpTableToSwitchPass::run(Function &F,
PostDominatorTree *PDT = AM.getCachedResult<PostDominatorTreeAnalysis>(F);
DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Lazy);
bool Changed = false;
+ InstrProfSymtab Symtab;
+ if (auto E = Symtab.create(*F.getParent()))
+ F.getContext().emitError(
+ "Could not create indirect call table, likely corrupted IR" +
+ toString(std::move(E)));
+ DenseMap<const Function *, GlobalValue::GUID> FToGuid;
+ for (const auto &[G, FPtr] : Symtab.getIDToNameMap())
+ FToGuid.insert({FPtr, G});
+
for (BasicBlock &BB : make_early_inc_range(F)) {
BasicBlock *CurrentBB = &BB;
while (CurrentBB) {
@@ -170,7 +229,12 @@ PreservedAnalyses JumpTableToSwitchPass::run(Function &F,
std::optional<JumpTableTy> JumpTable = parseJumpTable(GEP, PtrTy);
if (!JumpTable)
continue;
- SplittedOutTail = expandToSwitch(Call, *JumpTable, DTU, ORE);
+ SplittedOutTail = expandToSwitch(
+ Call, *JumpTable, DTU, ORE, [&](const Function &Fct) {
+ if (Fct.getMetadata(AssignGUIDPass::GUIDMetadataName))
+ return AssignGUIDPass::getGUID(Fct);
+ return FToGuid.lookup_or(&Fct, 0U);
+ });
Changed = true;
break;
}
diff --git a/llvm/test/Transforms/JumpTableToSwitch/basic.ll b/llvm/test/Transforms/JumpTableToSwitch/basic.ll
index 321f837077ab6..577c2adaf5afa 100644
--- a/llvm/test/Transforms/JumpTableToSwitch/basic.ll
+++ b/llvm/test/Transforms/JumpTableToSwitch/basic.ll
@@ -4,11 +4,11 @@
@func_array = constant [2 x ptr] [ptr @func0, ptr @func1]
-define i32 @func0() {
+define i32 @func0() !guid !0 {
ret i32 1
}
-define i32 @func1() {
+define i32 @func1() !guid !1 {
ret i32 2
}
@@ -42,7 +42,7 @@ define i32 @function_with_jump_table(i32 %index) {
;
%gep = getelementptr inbounds [2 x ptr], ptr @func_array, i32 0, i32 %index
%func_ptr = load ptr, ptr %gep
- %result = call i32 %func_ptr()
+ %result = call i32 %func_ptr(), !prof !2
ret i32 %result
}
@@ -226,3 +226,6 @@ define i32 @function_with_jump_table_addrspace_42(i32 %index) addrspace(42) {
ret i32 %result
}
+!0 = !{i64 5678}
+!1 = !{i64 5555}
+!2 = !{!"VP", i32 0, i64 25, i64 5678, i64 20, i64 5555, i64 5}
\ No newline at end of file
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
std::numeric_limits<uint32_t>::max(), TotalCount); | ||
|
||
for (const auto &[G, C] : Targets) { | ||
auto It = GuidToCounter.insert({G, C}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[[maybe_unused]]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! so much more elegant than (void)It.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ugh. saw this after I pushed merge, thought I canceled, then... well, race condition. Fixing in follow-up #153639
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/204/builds/18673 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/205/builds/18650 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/203/builds/19861 Here is the relevant piece of the build log for the reference
|
Fixing buildbot failures after PR #153305, e.g. https://lab.llvm.org/buildbot/#/builders/203/builds/19861 Analysis already depends on `ProfileData`, so the transitive closure of the dependencies of `ScalarOpts` doesn't change. Also avoided an extra dependency (and very unnecessary) on `Instrumentation`. The API previously used doesn't need to live in Instrumentation to begin with, but that's something to address in a follow-up.
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/160/builds/22976 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/180/builds/23121 Here is the relevant piece of the build log for the reference
|
If the indirect call target being recognized as a jump table has profile info, we can accurately synthesize the branch weights of the switch that replaces the indirect call.
Otherwise we insert the "unknown"
MD_prof
to indicate this is the best we can do here.Part of Issue #147390